pacman::p_load(
  tidyverse,
  yardstick,
  tictoc,
  dtplyr,
  ggthemes,
  ggrepel,
  fst,
  dtplyr,
  lubridate,
  cowplot,
  knitr,
  kableExtra
)

# Pick the trade-off metrics file: the "combined" run reads from
# plots_combined/, any other LOOP value reads from plots_<SUFFIX_FITTED_>/.
file_name_tradeoff <- if (LOOP == "combined") {
  paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, 
         "/plots_combined/data_tradeoff.fst")
} else {
  paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX,
         "/plots_", SUFFIX_FITTED_, "/data_tradeoff.fst")
}

df_tradeoffs_metrics <- read_fst(file_name_tradeoff)

# Prepare the trade-off metrics of one loss/profit ratio for plotting.
#
# Args:
#   df_metrics: data frame with (at least) the columns thresh_type,
#     model_type, loss_profit_ratio and profit.
#   lambda: value of loss_profit_ratio to keep.
#
# Returns:
#   The rows of df_metrics with loss_profit_ratio == lambda, excluding the
#   "min_cost" threshold and the "riskscore" model, with two added factor
#   columns (constraint, Model) and profit re-expressed as a percentage of the
#   baseline profit (logistic model, "single" threshold, same lambda).
#
# Raises:
#   An error if rows exist for `lambda` but the baseline is missing, or if
#   more than one baseline row matches (the rescaling would be ambiguous).
Prepare_Constraints <- function(df_metrics, lambda) {

  df_lambda <- df_metrics %>% 
    filter(loss_profit_ratio == lambda)

  profit_baseline_ <- df_lambda %>% 
    filter(thresh_type == "single", model_type == "logistic") %>% 
    pull(profit)

  # Fail early with a readable message instead of a mutate() size error or a
  # silent element-wise division against several baseline values.
  if (length(profit_baseline_) > 1) {
    stop("Prepare_Constraints: found ", length(profit_baseline_),
         " baseline rows (logistic / single) for lambda = ", lambda,
         "; expected exactly one.", call. = FALSE)
  }
  if (length(profit_baseline_) == 0 && nrow(df_lambda) > 0) {
    stop("Prepare_Constraints: no baseline row (logistic / single) for lambda = ",
         lambda, ".", call. = FALSE)
  }

  df_lambda %>% 
    mutate(
      constraint = case_when(
        thresh_type == "min_cost" ~ "Max Profit",
        thresh_type == "single" ~ "Blind",
        thresh_type == "tpr_weak" ~ "Weak",
        thresh_type == "tpr_med" ~ "Medium",
        thresh_type == "tpr_strong" ~ "Strong"
      ),
      constraint = factor(constraint, levels = c("Strong", "Medium", "Weak", "Blind", "Max Profit"))
    ) %>%
    mutate(
      `Model` = case_when(
        model_type == "xgb" ~ "XGB",
        model_type == "logistic" ~ "Logistic",
        model_type == "riskscore" ~ "Riskscore"
      ),
      Model = factor(Model, levels = c("Riskscore", "Logistic", "XGB"))
    ) %>% 
    # Factor levels keep "Max Profit" / "Riskscore" even though those rows are
    # dropped here.
    filter(thresh_type != "min_cost") %>% 
    filter(Model != "Riskscore") %>% 
    mutate(
      # Profit as a percentage of the baseline (baseline itself becomes 100).
      profit = (profit / profit_baseline_) * 100
    ) 
}

# Build the base profit-vs-TPR-difference plot: one point per row and one line
# per Model, with each point labelled by its constraint.
#
# Args:
#   df_lambda_plotting: data frame with columns profit, tpr_diff, Model and
#     constraint.
#
# Returns:
#   A ggplot object (the input data stays available as `$data`).
Make_Base_Tradeoff_Plot <- function(df_lambda_plotting) {
  
  tradeoff_mapping <- aes(profit, tpr_diff, color = Model, group = Model)
  
  tradeoff_theme <- theme(
    panel.grid.minor = element_blank(),
    text = element_text(family = "Avenir", size = 20),
    legend.position = "bottom"
  )
  
  ggplot(df_lambda_plotting, tradeoff_mapping) + 
    geom_point(size = 5, key_glyph = "path") + 
    # orientation = "y": the line connects the points ordered along the y axis.
    geom_line(size = 1.5, orientation = "y", key_glyph = "path") +
    geom_label(aes(label = constraint), family = "Avenir", vjust = 1, color = "black") +
    scale_color_ptol() +
    scale_x_continuous(limits = c(98.5, 102)) +
    theme_minimal() + 
    tradeoff_theme +
    labs(x = "Profit", y = "ΔTPR")
}

# Save a plot into the output folder of the current run.
#
# Args:
#   p_tradeoff_plot: plot object handed to ggsave().
#   file_name: file name (with extension) inside the output folder.
#   plot_width, plot_height: size in inches.
#   plot_dpi: resolution passed to ggsave().
#
# The folder is plots_combined/ when the global LOOP is "combined", otherwise
# plots_<SUFFIX_FITTED_>/ (both under ../../data/pipeline_outputs/<SPECIAL_SUFFIX>).
Save_Tradeoff_Panel <- function(p_tradeoff_plot, file_name, plot_width, plot_height, plot_dpi) {
  
  path_plot_out <- if (LOOP == "combined") {
    paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, 
           "/plots_combined/", file_name)
  } else {
    paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/plots_", 
           SUFFIX_FITTED_, "/", file_name)
  }
  
  ggsave(
    filename = path_plot_out,
    plot = p_tradeoff_plot,
    width = plot_width,
    height = plot_height,
    units = "in",
    dpi = plot_dpi
  )
}

# Base trade-off plot for lambda = 5, without any arrows.
df_lambda <- df_tradeoffs_metrics %>% 
  Prepare_Constraints(lambda = 5)

p_lambda <- df_lambda %>% 
  Make_Base_Tradeoff_Plot()

# The underlying data is exported only for the combined run.
if (LOOP == "combined") {
  write_csv(
    df_lambda,
    str_glue("../../data/pipeline_outputs/{SPECIAL_SUFFIX}/plots_combined/tradeoff_no_arrows.csv")
  )
}

Save_Tradeoff_Panel(
  p_tradeoff_plot = p_lambda,
  file_name = "tpr_profit_rs_logistic_no_arrows_lambda.png",
  plot_width = 7,
  plot_height = 7,
  plot_dpi = 1000
)

# Arrows Plot -------------------------------------------------------------

# Echo the plotting data to the console; print() returns its argument.
df_arrow <- print(df_lambda)

# Reference points used as arrow endpoints.
pt_blind_logistic_ <- df_arrow %>% 
  filter(constraint == "Blind", Model == "Logistic")
pt_blind_xgb_ <- df_arrow %>% 
  filter(constraint == "Blind", Model == "XGB")
pt_strong_xgb_ <- df_arrow %>% 
  filter(constraint == "Strong", Model == "XGB")

# Arrow 1: Logistic/Blind -> XGB/Blind.
x_ <- pt_blind_logistic_[["profit"]]
xend_ <- pt_blind_xgb_[["profit"]]
xdist_ <- xend_ - x_

y_ <- pt_blind_logistic_[["tpr_diff"]]
yend_ <- pt_blind_xgb_[["tpr_diff"]]
ydist_ <- yend_ - y_

# Share of the distance trimmed from each end of the arrow.
factor_ <- 0.15
xmid_ <- x_ + 0.5 * xdist_
ymid_ <- y_ + 0.5 * ydist_

# Arrow 2: Logistic/Blind -> XGB/Strong.
x1end_ <- pt_strong_xgb_[["profit"]]
x1dist_ <- x1end_ - x_

y1end_ <- pt_strong_xgb_[["tpr_diff"]]
y1dist_ <- y1end_ - y_

factor1_ <- 0.07
x1mid_ <- x_ + 0.5 * x1dist_
y1mid_ <- y_ + 0.5 * y1dist_

# New Model Arrows --------------------------------------------------------

# Dashed arrow from the Logistic/Blind point towards the XGB/Blind point,
# shortened at both ends by factor_ of the distance.
layer_new_model_arrow_ <- annotate(
  "segment",
  size = 2, color = "black", linetype = 2,
  x = x_  + xdist_ * factor_,
  y = y_ + ydist_ * factor_,
  xend = xend_ - xdist_ * factor_,
  yend = yend_ - ydist_ * factor_,
  linejoin = "mitre",
  arrow = arrow(type = "closed", length = unit(0.01, "npc"))
)

# Caption placed at the midpoint of the arrow.
layer_new_model_label_ <- annotate(
  "label",
  x = xmid_, y = ymid_,
  label = "New model only", family = "Avenir"
)

p_new_model_arrows_lambda <- p_lambda +
  layer_new_model_arrow_ +
  layer_new_model_label_

Save_Tradeoff_Panel(
  p_tradeoff_plot = p_new_model_arrows_lambda,
  file_name = "tpr_profit_rs_logistic_new_model_arrows.png",
  plot_width = 7,
  plot_height = 7,
  plot_dpi = 1000
)

# New Model and Constraint Arrows -----------------------------------------

# Add a second dashed arrow (Logistic/Blind -> XGB/Strong, shortened by
# factor1_ at both ends) and its caption on top of the first arrow plot.
p_new_model_and_constraint_arrows <- p_new_model_arrows_lambda +
  annotate("segment", size = 2, color = "black", linetype = 2,
           x = x_  + x1dist_ * factor1_, y = y_ + y1dist_ * factor1_,
           xend = x1end_ - x1dist_ * factor1_, yend = y1end_ - y1dist_ * factor1_,
           linejoin = "mitre",
           arrow = arrow(type = "closed", length = unit(0.01, "npc"))) +
  annotate("label", x = x1mid_, y = y1mid_, label = "New model\n+ constraint", family = "Avenir")

# Data frame the plot was built from.
df_new_model_and_arrow_constraints <- p_new_model_and_constraint_arrows %>% 
  pluck("data")

# Loss/profit ratio(s) present in that data; used in the CSV file name.
lambda_ <- df_new_model_and_arrow_constraints %>% 
  distinct(loss_profit_ratio) %>% 
  pull(loss_profit_ratio)

# The per-fit path is the default branch, so every LOOP value other than
# "combined" is handled the same way as elsewhere in this script. Previously
# an unmatched LOOP produced a NULL path and write_csv() failed.
# NOTE(review): in the combined case this CSV goes directly under
# <SPECIAL_SUFFIX>/, whereas the other combined outputs of this script go to
# plots_combined/ -- confirm this is intended.
path_plot_out <- switch (LOOP,
                         "combined" =  paste0("../../data/pipeline_outputs/",
                                              SPECIAL_SUFFIX, "/underlying_new_model_and_arrow_data_",
                                              lambda_, ".csv"),
                         paste0("../../data/pipeline_outputs/",
                                SPECIAL_SUFFIX, "/plots_",
                                SUFFIX_FITTED_, "/underlying_new_model_and_arrow_data_",
                                lambda_, ".csv")
)

write_csv(df_new_model_and_arrow_constraints, path_plot_out)


Save_Tradeoff_Panel(p_new_model_and_constraint_arrows, 
                      "tpr_profit_rs_logistic_xgb_model_constraint_arrows.png",
                      7, 7, 1000)

# Lambda for Cowplot ------------------------------------------------------

# One base trade-off plot per lambda (3, 5, 7), all on the same axis limits.
# xlim()/ylim() replace the x scale set inside Make_Base_Tradeoff_Plot().
l_lambda_panels_ <- lapply(c(3, 5, 7), function(lambda_panel) {
  Make_Base_Tradeoff_Plot(
    Prepare_Constraints(df_tradeoffs_metrics, lambda = lambda_panel)
  ) +
    xlim(98, 102.5) +
    ylim(-.12, 0.01)
})

p_smallest_lambda <- l_lambda_panels_[[1]]        # lambda = 3
p_third_smallest_lambda <- l_lambda_panels_[[2]]  # lambda = 5
p_fourth_smallest_lambda <- l_lambda_panels_[[3]] # lambda = 7

# Data behind each panel, in the same order (3, 5, 7).
l_tradeoff_panels_underlying_data <- lapply(
  l_lambda_panels_,
  function(p_panel) p_panel$data
)

# Write the data behind one trade-off panel to
# underlying_panels_data_<lambda>.csv in the output folder of the run.
#
# Args:
#   df_underlying_data: data frame with a loss_profit_ratio column; its
#     distinct value is used in the file name.
#   LOOP: run type. "combined" writes to plots_combined/; any other value
#     writes to plots_<SUFFIX_FITTED_>/, matching how the rest of this script
#     treats LOOP (previously values other than "windows"/"combined" silently
#     wrote nothing).
#
# Returns:
#   df_underlying_data, invisibly (the value of write_csv()).
#
# Raises:
#   An error if the data contains more than one loss_profit_ratio, because a
#   single file name could not be built.
Save_Underlying_Panels_Data <- function(df_underlying_data, LOOP) {
  
  lambda <- df_underlying_data %>% 
    distinct(loss_profit_ratio) %>% 
    pull(loss_profit_ratio)
  
  if (length(lambda) > 1) {
    stop("Save_Underlying_Panels_Data: expected a single loss_profit_ratio, found ",
         length(lambda), ".", call. = FALSE)
  }
  
  dir_out <- if (LOOP == "combined") {
    paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/plots_combined/")
  } else {
    paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/plots_",
           SUFFIX_FITTED_, "/")
  }
  
  write_csv(df_underlying_data,
            paste0(dir_out, "underlying_panels_data_", lambda, ".csv"))
}

# Export the data behind every panel.
walk(
  l_tradeoff_panels_underlying_data,
  function(df_panel) Save_Underlying_Panels_Data(df_panel, LOOP)
)


# 2 x 2 grid: lambda = 7 top-left, an empty cell top-right, then lambda = 5
# and lambda = 3 on the bottom row.
panel_labels_ <- c("A: λ=7", "", "B: λ=5 (Main Specification)", "C: λ=3")
panel_hjust_ <- c(-.5, -.5, .05, -.5)

p_lambda_cowplot <- plot_grid(
  p_fourth_smallest_lambda, NULL, p_third_smallest_lambda, p_smallest_lambda,
  labels = panel_labels_,
  label_fontfamily = "Avenir",
  label_x = .4,
  hjust = panel_hjust_,
  nrow = 2
)


Save_Tradeoff_Panel(
  p_tradeoff_plot = p_lambda_cowplot,
  file_name = "tradeoff_panels.png",
  plot_width = 14,
  plot_height = 14,
  plot_dpi = 500
)



# Winners and Losers ------------------------------------------------------

# Same folder convention as for the trade-off data: plots_combined/ for the
# "combined" run, plots_<SUFFIX_FITTED_>/ otherwise.
file_name_win_lose <- if (LOOP == "combined") {
  paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, 
         "/plots_combined/data_win_lose.fst")
} else {
  paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/",
         "plots_", SUFFIX_FITTED_, "/data_win_lose.fst")
}

df_win_lose_metrics <- read_fst(file_name_win_lose)

# Shows the available loss/profit ratios when run interactively.
unique(df_win_lose_metrics$loss_profit_ratio)

# Plotting data for the winners/losers figure: XGB only, lambda = 5, without
# the "min_cost" threshold.
df_gg <- df_win_lose_metrics %>% 
  filter(
    loss_profit_ratio == 5,
    thresh_type != "min_cost",
    model_type == "xgb"
  ) %>% 
  mutate(model_thresh = paste0(model_type, "_", thresh_type)) %>% 
  # Row order matters: lag() below compares each row with the previous
  # constraint inside its (income level, metric) group.
  arrange(model_type, d_Income_Level, metric, desc(constraint)) %>% 
  group_by(d_Income_Level, metric) %>% 
  mutate(
    # Percentage label, shown only for the "Blind" and "Strong" constraints.
    label_level = paste0(round(value * 100, digits = 1), "%"),
    label_level = ifelse(constraint %in% c("Blind", "Strong"), label_level, NA),
    
    # Signed change versus the previous row of the group, in percentage points.
    label_diff = value - lag(value, 1),
    label_diff = round(label_diff * 100, 2),
    label_diff = ifelse(label_diff > 0, paste0("+", label_diff, "pp"), paste0(label_diff, "pp")),
    group = paste0(metric, "_", model_type),
    # NOTE(review): str_to_title() turns d_Income_Level into character before
    # it is matched against the levels c(1, 0) below -- presumably the column
    # holds 0/1 codes; confirm.
    d_Income_Level = str_to_title(d_Income_Level),
    Metric = case_when(
      metric == "fn" ~ "F. Neg.",
      metric == "tn" ~ "T. Neg.",
      metric == "tp" ~ "T. Pos.",
      metric == "fp" ~ "F. Pos."
    ),
    Metric = factor(Metric, levels = c("T. Pos.", "T. Neg.", "F. Neg.", "F. Pos."))
  ) %>% 
  mutate(
    facet_LMI = factor(d_Income_Level, levels = c(1, 0), labels = c("Non-LMI", "LMI"))
  ) %>% 
  glimpse()

# Winners/losers figure: value by constraint (reversed factor order), one line
# per Metric, faceted by income group.
theme_winners_losers_ <- theme(
  text = element_text(family = "Avenir", size = 15),
  axis.text.x = element_text(angle = 45, vjust = 1, hjust=1),
  axis.title.x = element_blank()
)

p_winners_losers <- ggplot(df_gg, aes(fct_rev(constraint), value, color = Metric, group = group)) +
  facet_wrap(~ facet_LMI) +
  geom_point(size = 3, key_glyph = "path") + 
  geom_line(size = 1.5, key_glyph = "path") + 
  scale_color_ptol() +
  geom_label_repel(aes(label = label_level), size = 3, key_glyph = "path", family = "Avenir") +
  theme_minimal() + 
  scale_y_continuous(labels = scales::percent) +
  labs(y = "% of Consumers") +
  theme_winners_losers_

# Output folder of the current run (plots_combined/ for the combined run).
path_plot_out <- if (LOOP == "combined") {
  paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, 
         "/plots_combined/")
} else {
  paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/plots_", 
         SUFFIX_FITTED_, "/")
}

ggsave(
  filename = paste0(path_plot_out, "winners_losers_xgb.png"),
  plot = p_winners_losers,
  width = 7,
  height = 7,
  units = "in",
  dpi = 1000
)



# Lambda table ------------------------------------------------------------

# Build the table behind lambda_tradeoffs.tex: one row per
# (variable, model, policy) and one column per loss/profit ratio.
# Fairness is tpr_diff * 100; profit is expressed as a percentage of the
# logistic / single-threshold profit at the same ratio.
df_tb <- df_tradeoffs_metrics %>% 
  filter(model_type != "riskscore") %>% 
  select(model = model_type, thresh = thresh_type, lpr = loss_profit_ratio, fair = tpr_diff, profit) %>% 
  group_by(lpr) %>% 
  mutate(
    fair = round(fair * 100, 1)
  ) %>% 
  filter(lpr %in% c(1, 3, 5, 7, 9, 11)) %>% 
  mutate(
    # Shorten the names so that they split cleanly on "_" further down:
    # "tpr_weak" -> "weak", "logistic" -> "log".
    thresh = str_remove(thresh, "tpr_"),
    model = str_remove(model, "istic")
  ) %>% 
  pivot_wider(names_from = c(model, thresh), values_from = c(fair, profit)) %>% 
  # Freeze the baseline before rescaling. profit_log_single is itself one of
  # the rescaled columns, so dividing by it directly inside across() can make
  # columns processed after it divide by 100 instead of the baseline.
  mutate(baseline_profit_ = profit_log_single) %>% 
  mutate(
    across(c(starts_with("profit_")), ~round((. / baseline_profit_) * 100, 1))
  ) %>% 
  select(-baseline_profit_) %>% 
  pivot_longer(cols = c(starts_with("fair_"), starts_with("profit_"))) %>% 
  pivot_wider(names_from = lpr, values_from = value) %>% 
  # NOTE(review): separate() splits on every non-alphanumeric character, so a
  # threshold name that still contains "_" (e.g. "min_cost", if present in
  # the data) would yield extra pieces and an NA Policy -- confirm such rows
  # are absent.
  separate(name, into = c("var", "model", "thresh"), remove = FALSE) %>% 
  mutate(
    var = factor(var, levels = c("fair", "profit"), labels = c("Fairness", "Profit")),
    model = factor(model, levels = c("log", "xgb"), labels = c("Log", "XGB")),
    thresh = factor(thresh, levels = c("single", "weak", "med", "strong"), labels = c("Blind", "Weak", "Medium", "Strong"))
  ) %>% 
  arrange(var, model, thresh) %>% 
  print()

# Fonts for the two stacked row-group label levels passed to collapse_rows():
# first level bold italic, second level plain.
# TRUE/FALSE are spelled out because T/F are ordinary, reassignable variables.
row_group_label_fonts <- list(
  list(bold = TRUE, italic = TRUE),
  list(bold = FALSE, italic = FALSE)
)

# Render df_tb as a booktabs LaTeX table. After dropping `name` the first
# three columns are var / model / thresh; all remaining columns (one per
# loss/profit ratio) are grouped under a single header.
tb <- df_tb %>% 
  select(-name) %>% 
  rename(" " = model, Policy = thresh) %>% 
  kable(format = "latex", booktabs = TRUE, escape = FALSE) %>% 
  # df_tb has 4 non-ratio columns (name, var, model, thresh).
  add_header_above(c(" " = 3, "\\\\lambda" = ncol(df_tb) - 4), escape = FALSE) %>%
  collapse_rows(1:2, latex_hline = 'custom', custom_latex_hline = 1:2,
                row_group_label_position = 'stack',
                row_group_label_fonts = row_group_label_fonts)

# NOTE(review): the table is always written to plots_combined/, whatever the
# value of LOOP -- confirm this is intended.
write_file(tb, paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, 
                   "/plots_combined/lambda_tradeoffs.tex"))
  
